package com.ai.common.gpt;

import com.ai.common.enums.ConfigModuleEnum;
import com.ai.common.utils.ConfigUtil;
import com.ai.common.utils.TokenUtil;
import com.ai.sevice.ai.remote.CompletionsRequest;
import com.ai.sevice.ai.remote.CompletionsResponse;
import com.ai.sevice.ai.remote.Message;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import javax.annotation.PostConstruct;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.HashMap;

/**
 * @author ：shengwenyang
 * @date ：Created in 2023/4/11 17:40
 * @description：
 * @version:
 */
@Slf4j
@Component
public class OpenAiWebClient {

    private static final String chatGptUrl = "https://api.openai.com/v1/chat/completions";

    private WebClient webClient;

    @Autowired
    private RestTemplate restTemplate;

    @PostConstruct
    public void initProd() {
        log.info("initProd");
        this.webClient = WebClient.builder()
                .defaultHeader(HttpHeaders.CONTENT_TYPE, "application/json")
                .build();
    }

    public Flux<String> getChatResponse(String userId, String prompt, BigDecimal maxTokens, BigDecimal temperature, Double topP) {
        JSONObject params = new JSONObject();

        params.put("model", "gpt-3.5-turbo");
        params.put("max_tokens", maxTokens);
        params.put("stream", true);
        params.put("temperature", temperature);
        if (topP != null) {
            params.put("top_p", topP);
        }
        if (StringUtils.isNotBlank(userId)) {
            params.put("user", userId);
        }
        JSONObject message = new JSONObject();
        message.put("role", "user");
        message.put("content", prompt);
        params.put("messages", Collections.singleton(message));

        Flux<String> stringFlux = webClient.post()
                .uri(ApiConstant.CHAT_API)
                .header(HttpHeaders.AUTHORIZATION, TokenUtil.pollChatGptToken())
                .bodyValue(params.toJSONString())
                .retrieve()
                .bodyToFlux(String.class)
                .onErrorResume(WebClientResponseException.class, ex -> {
                    HttpStatus status = ex.getStatusCode();
                    String res = ex.getResponseBodyAsString();
                    log.error("OpenAI API error: {} {}", status, res);
                    return Mono.error(new RuntimeException(res));
                });

        return stringFlux;

    }

    public CompletionsResponse chatCompletions(String content){
        String maxTokens = ConfigUtil.getInstance().getConfig(ConfigModuleEnum.DEFAULT.getModule(), "maxTokens");
        if (StringUtils.isNotBlank(maxTokens)){
            return chatCompletions(content,new BigDecimal(maxTokens));
        }
        return chatCompletions(content,null);
    }

    public CompletionsResponse chatCompletions(String content,BigDecimal maxTokens){
        HttpHeaders head = new HttpHeaders();
        head.add("Content-Type","application/json");
        head.add("Authorization", TokenUtil.pollChatGptToken());
        CompletionsRequest completionsRequest = new CompletionsRequest();
        completionsRequest.setMax_tokens(maxTokens);
        completionsRequest.setTemperature(new BigDecimal("0.5"));
        completionsRequest.setModel("gpt-3.5-turbo");
        completionsRequest.setFrequency_penalty(new BigDecimal(0));
        completionsRequest.setPresence_penalty(new BigDecimal(0));
        Message message = new Message();
        message.setContent(content);
        message.setRole("user");
        completionsRequest.setMessages(Collections.singletonList(message));
        HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(completionsRequest),head);
        ResponseEntity<CompletionsResponse> responseEntity = restTemplate.postForEntity(chatGptUrl, entity, CompletionsResponse.class,new HashMap<>());
        return responseEntity.getBody();
    }

}
